feat(flash-attn): add Python DSL Flash Attention example under kernels/python/flash_atten #117

Open
chenshengxin2026 wants to merge 1 commit into hw-native-sys:main from chenshengxin2026:feat/flash-attn-v1

@chenshengxin2026 chenshengxin2026 commented May 7, 2026

Summary

Add a Python DSL Flash Attention example under kernels/python/flash_atten/.

  • The example is developed on top of the in-tree manual kernel kernels/manual/common/flash_atten (same four-stage Cube/Vector pipeline compute_qk -> compute_p -> compute_pv -> compute_gu, same TILE_S1=256 / CUBE_S1=128 / QK_PRELOAD=4 shape and FIFO layout).
  • Design also references the Huawei CSL PTO DSL AOT Flash Attention 140 TFLOPS example: https://github.com/huawei-csl/pto-dsl/tree/main/examples/aot/flash_attention/140tflops (benchmark conventions, TFLOP/s accounting, case set).
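For readers unfamiliar with the four-stage split, the following is a minimal pure-Python sketch of tiled attention with an online softmax, with each step labeled by the stage name it plausibly corresponds to. It is not the kernel's code: the helper names are hypothetical, the reading of compute_gu as a "global update" (rescale and accumulate) step is an assumption, and there is no Cube/Vector overlap, preload, or FIFO here.

```python
import math

def matmul(a, b):
    """Multiply an (m x k) matrix by a (k x n) matrix, both lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def reference_attention(q, k, v):
    """Plain softmax(Q K^T / sqrt(d)) V, computed one query row at a time."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for row in matmul(q, transpose(k)):
        scores = [x * scale for x in row]
        m = max(scores)
        e = [math.exp(x - m) for x in scores]
        z = sum(e)
        out.append([sum(w * v_row[c] for w, v_row in zip(e, v)) / z
                    for c in range(len(v[0]))])
    return out

def tiled_attention(q, k, v, tile_s1):
    """Same result, but K/V are consumed tile_s1 rows at a time."""
    scale = 1.0 / math.sqrt(len(q[0]))
    n_q, d_v = len(q), len(v[0])
    row_max = [-math.inf] * n_q            # running max per query row
    row_sum = [0.0] * n_q                  # running softmax denominator
    acc = [[0.0] * d_v for _ in range(n_q)]
    for start in range(0, len(k), tile_s1):
        k_tile = k[start:start + tile_s1]
        v_tile = v[start:start + tile_s1]
        scores = matmul(q, transpose(k_tile))            # compute_qk
        probs, rescale = [], []
        for i, row in enumerate(scores):                 # compute_p
            scaled = [x * scale for x in row]
            new_max = max(row_max[i], max(scaled))
            alpha = math.exp(row_max[i] - new_max)       # 0.0 on the first tile
            e = [math.exp(x - new_max) for x in scaled]
            row_sum[i] = row_sum[i] * alpha + sum(e)
            row_max[i] = new_max
            probs.append(e)
            rescale.append(alpha)
        pv = matmul(probs, v_tile)                       # compute_pv
        for i in range(n_q):                             # compute_gu (assumed meaning)
            acc[i] = [a * rescale[i] + b for a, b in zip(acc[i], pv[i])]
    return [[a / row_sum[i] for a in acc[i]] for i in range(n_q)]

def demo_matrix(rows, cols, seed):
    """Deterministic toy data so the check below is reproducible."""
    return [[math.sin(seed + 3 * i + j) for j in range(cols)] for i in range(rows)]

q, k, v = demo_matrix(2, 4, 0.1), demo_matrix(8, 4, 0.7), demo_matrix(8, 4, 1.3)
out_ref = reference_attention(q, k, v)
out_tiled = tiled_attention(q, k, v, tile_s1=2)
max_err = max(abs(a - b) for ra, rb in zip(out_ref, out_tiled) for a, b in zip(ra, rb))
```

The toy sizes (2 query rows, 8 KV rows, tile of 2) are chosen only to keep the sketch fast; the kernel itself uses S0 = 128 and TILE_S1 = 256.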

Files

  • kernels/python/flash_atten/kernels/fa_builder.py — PTO Python DSL kernel builder (ptodsl).
  • kernels/python/flash_atten/caller.cpp — host shim exported as call_kernel for ctypes.
  • kernels/python/flash_atten/compile.sh — ptoas + bisheng build pipeline (build_artifacts/fa.mlir, fa.cpp, fa.so).
  • kernels/python/flash_atten/run.py — Torch-NPU driver: build, correctness check vs FP32 reference / torch_npu.npu_fused_infer_attention_score, sweep case1..case8 (S0=S1 from 1024 to 131072), TFLOP/s report, TSV summary.
  • kernels/python/flash_atten/README.md, README_zh.md — usage, supported platform, build/run, custom cases.
  • .gitignore — ignore kernels/python/flash_atten/build_artifacts/.

Kernel scope

  • HEAD = 128, S0 = 128 per Q block, TILE_S1 = 256, CUBE_S1 = 128, QK_PRELOAD = 4, non-causal only.
  • Total Q rows configured by FA_Q_ROWS (multiple of 128); total KV rows supplied at runtime; each S1 must be compatible with S1_TILE=256 and QK_PRELOAD=4.
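The constraints above could be checked before launching a case with something like the sketch below. The function name is hypothetical, and the PR does not spell out what "compatible" means for S1; this sketch assumes it means a multiple of the 256-row tile that provides at least QK_PRELOAD tiles.

```python
HEAD = 128        # only supported head dimension
Q_BLOCK = 128     # S0 rows per Q block
TILE_S1 = 256     # KV rows per tile (written S1_TILE in the text above)
QK_PRELOAD = 4    # number of QK tiles preloaded by the pipeline

def check_shapes(q_rows, kv_rows, head=HEAD):
    """Return a list of human-readable problems; an empty list means OK."""
    problems = []
    if head != HEAD:
        problems.append(f"HEAD must be {HEAD}, got {head}")
    if q_rows <= 0 or q_rows % Q_BLOCK:
        problems.append(f"FA_Q_ROWS must be a positive multiple of {Q_BLOCK}, got {q_rows}")
    # ASSUMPTION: "compatible" is interpreted as the two rules below.
    if kv_rows <= 0 or kv_rows % TILE_S1:
        problems.append(f"S1 must be a positive multiple of {TILE_S1}, got {kv_rows}")
    elif kv_rows // TILE_S1 < QK_PRELOAD:
        problems.append(f"S1 must span at least {QK_PRELOAD} tiles, got {kv_rows // TILE_S1}")
    return problems
```

Under this reading, every length in the default sweep (1024 to 131072) passes, since each is a multiple of 1024.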

Build & run

source ${ASCEND_INSTALL_PATH}/bin/setenv.bash
cd ${repo}/kernels/python/flash_atten
export PTO_LIB_PATH=${repo}
python3 run.py                # full default suite case1..case8
python3 run.py --case case1   # single case
FA_Q_ROWS=2048 FA_BENCH_LENGTHS=1024,2048,4096 python3 run.py
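As an illustration of how the two environment variables in the last command might be consumed, here is a small sketch. It is not taken from run.py: the function name, the `None` fallback for an unset FA_Q_ROWS, and the way defaults are built are all assumptions; only the variable names and the 1024..131072 default sweep come from the description.

```python
import os

# case1..case8 as described in the PR: S0=S1 doubling from 1024 to 131072.
DEFAULT_LENGTHS = [1024 << i for i in range(8)]

def bench_config(environ=os.environ):
    """Parse FA_Q_ROWS and FA_BENCH_LENGTHS from a mapping of env vars."""
    q_rows_raw = environ.get("FA_Q_ROWS")
    lengths_raw = environ.get("FA_BENCH_LENGTHS")
    # ASSUMPTION: an unset FA_Q_ROWS means "let the driver pick its default".
    q_rows = int(q_rows_raw) if q_rows_raw else None
    if lengths_raw:
        lengths = [int(item) for item in lengths_raw.split(",") if item.strip()]
    else:
        lengths = list(DEFAULT_LENGTHS)
    return q_rows, lengths
```

Passing the mapping explicitly, as in `bench_config({"FA_Q_ROWS": "2048", "FA_BENCH_LENGTHS": "1024,2048,4096"})`, keeps the parser testable without touching the real environment.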

Performance

Ascend 910B1, HEAD=128, S0=S1, non-causal. Timing uses do_bench with warmup=10, iter=100. fa_us is the kernel-only time in microseconds, measured on the host as the window between stream events; TF/s uses matmul + scale + softmax FLOP counts (same accounting as the 140 TFLOPS reference).
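The TF/s accounting can be approximated as follows. The matmul term (two matmuls, 2 FLOPs per multiply-add) is standard; the constant of 6 elementwise FLOPs per score for scale + softmax is an assumption inferred by fitting the reported numbers, not a value stated in this PR or checked against the reference code. The function names are hypothetical.

```python
def attention_flops(s0, s1, head, elementwise_per_score=6):
    """FLOP count for one non-causal attention pass.

    matmul:      Q K^T and P V, each 2 * s0 * s1 * head FLOPs.
    elementwise: scale + softmax work per score element (ASSUMED constant).
    """
    matmul = 4 * s0 * s1 * head
    elementwise = elementwise_per_score * s0 * s1
    return matmul + elementwise

def tflops(s0, s1, head, time_us):
    """Convert a kernel time in microseconds into TFLOP/s."""
    return attention_flops(s0, s1, head) / (time_us * 1e-6) / 1e12

case1_pto_isa = tflops(1024, 1024, 128, 18.92)
case8_pto_isa = tflops(131072, 131072, 128, 54178.87)
```

With that constant, the sketch lands within rounding of the reported pto-isa figures for case1 (28.71 TF/s) and case8 (164.26 TF/s).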

Python DSL vs Huawei CSL 140 TFLOPS reference

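| case  | S0=S1  | 140tflops fa_us | 140 TF/s | pto-isa fa_us | pto-isa TF/s | 140 / pto-isa time |
|-------|--------|-----------------|----------|---------------|--------------|--------------------|
| case1 | 1024   | 23.65           | 22.97    | 18.92         | 28.71        | 1.25x              |
| case2 | 2048   | 44.30           | 49.04    | 48.32         | 44.97        | 0.92x              |
| case3 | 4096   | 130.21          | 66.74    | 206.22        | 42.14        | 0.63x              |
| case4 | 8192   | 324.84          | 107.02   | 383.91        | 90.55        | 0.85x              |
| case5 | 16384  | 1097.10         | 126.74   | 1038.72       | 133.87       | 1.06x              |
| case6 | 32768  | 3819.87         | 145.61   | 3456.58       | 160.91       | 1.11x              |
| case7 | 65536  | 14454.41        | 153.92   | 13527.26      | 164.47       | 1.07x              |
| case8 | 131072 | 56926.69        | 156.33   | 54178.87      | 164.26       | 1.05x              |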

Also vs torch_npu.npu_fused_infer_attention_score (multi-core fused kernel, upper-bound reference only)

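| case  | S0=S1  | 140_us   | pto-isa_us | fused_us | 140_TF/s | pto-isa_TF/s | fused_TF/s | 140 / pto-isa time | 140 speedup vs fused | pto-isa speedup vs fused |
|-------|--------|----------|------------|----------|----------|--------------|------------|--------------------|----------------------|--------------------------|
| case1 | 1024   | 23.65    | 18.92      | 55.92    | 22.97    | 28.71        | 9.71       | 1.25x              | 2.36x                | 2.96x                    |
| case2 | 2048   | 44.30    | 48.32      | 73.27    | 49.04    | 44.97        | 29.65      | 0.92x              | 1.65x                | 1.52x                    |
| case3 | 4096   | 130.21   | 206.22     | 137.52   | 66.74    | 42.14        | 63.19      | 0.63x              | 1.06x                | 0.67x                    |
| case4 | 8192   | 324.84   | 383.91     | 294.67   | 107.02   | 90.55        | 117.97     | 0.85x              | 0.91x                | 0.77x                    |
| case5 | 16384  | 1097.10  | 1038.72    | 855.44   | 126.74   | 133.87       | 162.55     | 1.06x              | 0.78x                | 0.82x                    |
| case6 | 32768  | 3819.87  | 3456.58    | 3093.62  | 145.61   | 160.91       | 179.79     | 1.11x              | 0.81x                | 0.89x                    |
| case7 | 65536  | 14454.41 | 13527.26   | 12093.48 | 153.92   | 164.47       | 183.97     | 1.07x              | 0.84x                | 0.89x                    |
| case8 | 131072 | 56926.69 | 54178.87   | 48203.22 | 156.33   | 164.26       | 184.62     | 1.05x              | 0.85x                | 0.89x                    |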
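The three ratio columns follow directly from the timing columns. A short sketch of that arithmetic, using a hypothetical helper name and the case1 and case3 timings from the table:

```python
def derived_columns(ref_us, dsl_us, fused_us):
    """Compute the ratio columns from the three per-case kernel times.

    ref_us:   140 TFLOPS reference time (140_us)
    dsl_us:   this PR's Python DSL kernel time (pto-isa_us)
    fused_us: torch_npu fused kernel time (fused_us)
    Values above 1 mean the kernel named first in the column is slower
    (for the time ratio) or faster (for the speedups).
    """
    return {
        "140 / pto-isa time": ref_us / dsl_us,
        "140 speedup vs fused": fused_us / ref_us,
        "pto-isa speedup vs fused": fused_us / dsl_us,
    }

case1 = derived_columns(23.65, 18.92, 55.92)
case3 = derived_columns(130.21, 206.22, 137.52)
```

A "140 / pto-isa time" above 1.0x therefore means the Python DSL kernel finished sooner than the 140 TFLOPS reference for that case.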

Headline:

  • Roughly on par with the Huawei CSL 140 TFLOPS reference across the full S1 sweep (140 / pto-isa time between 0.63x and 1.25x in all 8 cases): faster at case1 and case5..case8, slower at case2..case4.
  • Peak ~164 TF/s at long S1 (case7..case8; case6 reaches ~161 TF/s), about 0.89x the throughput of the fused multi-core kernel (~180..185 TF/s).

Testing

  • Build: compile.sh produces build_artifacts/fa.so per FA_Q_ROWS.
  • Correctness: each case passes the in-bench comparison (FP32 torch reference when feasible, else torch_npu.npu_fused_infer_attention_score).
  • Benchmark: python3 run.py runs case1..case8 end-to-end and emits the per-case TFLOP/s and a TSV summary.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request implements a pto-dsl version of Flash Attention, including the kernel builder, host shim, and benchmarking infrastructure. The review identified several high-priority issues: missing error handling for rtGetC2cCtrlAddr which could lead to null pointer dereferences, and buffer aliasing that undermines the software pipeline's performance and safety. Additionally, there is a discrepancy between the implemented tile size and the documentation, and several instances where tile.muls should be replaced with tile.mov for better efficiency.

Comment on lines +29 to +30
(void)rtGetC2cCtrlAddr(reinterpret_cast<uint64_t *>(&fftsAddr), &fftsLen);
(void)fftsLen;

high

The return value of rtGetC2cCtrlAddr is ignored. If this function fails, fftsAddr will remain nullptr, which will cause the kernel to crash or exhibit undefined behavior when pto.set_ffts is called inside the kernel. You should check the return code and handle potential errors appropriately.

Comment thread kernels/python/flash_atten-v1/kernels/fa_builder.py Outdated
Comment thread kernels/python/flash_atten-v1/kernels/fa_builder.py Outdated
Comment thread kernels/python/flash_atten-v1/kernels/fa_builder.py Outdated
Comment thread kernels/python/flash_atten-v1/kernels/fa_builder.py Outdated
chenshengxin2026 changed the title from "Add pto-dsl Python port of Flash Attention v1 perf kernel" to "feat: Add PTO-DSL Flash Attention v1 performance kernel with Python validation runner" May 7, 2026
chenshengxin2026 changed the title from "feat: Add PTO-DSL Flash Attention v1 performance kernel with Python validation runner" to "[WIP] feat: Add PTO-DSL Flash Attention v1 performance kernel with Python validation runner" May 7, 2026
Collaborator

zhoubot commented May 8, 2026

Triage review (2026-05-08): this PR is mergeable from a repository-state perspective: GitHub reports it as clean against main, and all four CI checks pass (Pre-commit, Docs build, CPU SIM smoke, CPU SIM full ST). I reviewed the changed file set and the runner/build flow at a high level.

Before merging, please resolve the readiness signal: the title still says [WIP], while the PR is not draft and CI is green. Either retitle it as ready or mark it draft if the known long-sequence performance gap is still a blocker. I would also prefer a small README/index entry for kernels/python/flash_atten-v1/, because the body contains important constraints (S1_TILE=32, QK_PRELOAD=2, causal unsupported, long-S1 gap) that will be hard to discover after merge.

No blocking conflict found in this triage pass; this still deserves normal owner review of the PTO-DSL kernel semantics and the NPU benchmark claims before merge.

@chenshengxin2026
Author

Added kernels/python/flash_atten-v2/ (commit 35b35de) — manual-aligned variant at TILE_S1=256, CUBE_S1=128, kTileFactor=2, all three pipes on the address-based slot model from PTOAS PR #606. Single-call correctness PASSED at S1=1024 (max_err 4.43e-05) and S1=2048 (max_err 2.72e-05); S1>=4096 currently aicore-timeouts because ptoas emits TPipe<...,SlotNum=8,LocalSlotNum=8,...> for globaltensor pipe init instead of the manual's ...,8,2,.... Filed details + suggested fix as #118.

chenshengxin2026 changed the title from "[WIP] feat: Add PTO-DSL Flash Attention v1 performance kernel with Python validation runner" to "[WIP] feat: Add PTO-DSL Flash Attention performance kernel with Python validation runner" May 9, 2026
chenshengxin2026 changed the title from "[WIP] feat: Add PTO-DSL Flash Attention performance kernel with Python validation runner" to "feat(flash-attn): add Python DSL Flash Attention example under kernels/python/flash_atten" May 18, 2026
…s/python/flash_atten

Port the manual Flash Attention kernel under kernels/manual/common/flash_atten
to the PTO Python DSL (ptodsl) and add a build/run/benchmark entry point.

- kernels/python/flash_atten/kernels/fa_builder.py: PTO Python DSL builder
  for a four-stage Cube/Vector software pipeline
  (compute_qk -> compute_p -> compute_pv -> compute_gu) with TILE_S1=256,
  CUBE_S1=128, QK_PRELOAD=4, matching the manual kernel shape.
- kernels/python/flash_atten/caller.cpp: host shim exported as call_kernel
  for ctypes.
- kernels/python/flash_atten/compile.sh: ptoas + bisheng build pipeline,
  emits build_artifacts/fa.mlir, fa.cpp, fa.so.
- kernels/python/flash_atten/run.py: Torch-NPU driver with correctness check
  vs FP32 torch reference / npu_fused_infer_attention_score and a sweep
  over case1..case8 (S0=S1 from 1024 up to 131072), TFLOP/s report and
  TSV summary.
- kernels/python/flash_atten/README.md, README_zh.md: usage, supported
  platform, build/run, custom cases, output format.
- .gitignore: ignore kernels/python/flash_atten/build_artifacts/.

Design references:
- kernels/manual/common/flash_atten (in-tree manual kernel; pipeline,
  TILE_S1/CUBE_S1/QK_PRELOAD shape and FIFO layout)
- https://github.com/huawei-csl/pto-dsl/tree/main/examples/aot/flash_attention/140tflops
  (AOT 140 TFLOPS reference; benchmark conventions, TFLOP/s accounting)